library(ggplot2)
library(RColorBrewer)
library(ISLR)
library(latex2exp)
# Shared plotting defaults: a black-and-white ggplot2 theme with a larger
# base font, and the 9-colour "Set1" palette used below for curves and bands.
theme_set(new = theme_bw(base_size = 14))
cols <- brewer.pal(n = 9, name = "Set1")

Beyond the linear model

Let’s first try to extend the linear model with the two approaches seen in class: polynomial regression and step functions.

Polynomial regression

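We have seen polynomial regression in a previous R laboratory. Here we use the same strategy: remember that you can include polynomial terms either with the syntax I() (raw powers) or with poly(), which by default uses an orthogonal polynomial basis (same fitted values, different coefficients).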

data(Auto)

# Degree-5 polynomial regression of mpg on horsepower. poly() builds an
# orthogonal polynomial basis by default: the fitted curve is the same as the
# raw-power model below, but the individual coefficients differ.
# fit <- lm(mpg ~ horsepower + I(horsepower^2) + I(horsepower^3) +
#                 I(horsepower^4) + I(horsepower^5), data = Auto)
fit <- lm(mpg ~ poly(horsepower, degree = 5), data = Auto)

# Prediction grid: the observed horsepower range plus 20 units on each side
# (the outer parts of the grid are extrapolation).
xgrid <- seq(min(Auto$horsepower) - 20, max(Auto$horsepower) + 20,
             length.out = 1000)

# Spell out `se.fit` (rather than relying on partial matching of `se`) to get
# pointwise standard errors of the fitted mean.
pred <- predict(fit, newdata = list(horsepower = xgrid), se.fit = TRUE)
y_hat <- as.numeric(pred$fit)
sd_f <- as.numeric(pred$se.fit)

par(mar = c(4, 4, 2, 2), family = 'serif')
plot(Auto$horsepower, Auto$mpg, pch = 16, main = 'Polynomial regression',
     xlab = 'Horsepower', ylab = 'miles per gallon')
# Shaded band: fitted value +/- 2 standard errors (approx. 95% pointwise CI).
polygon(c(xgrid, rev(xgrid)), c(y_hat - 2 * sd_f, rev(y_hat + 2 * sd_f)),
        col = alpha(cols[3], 0.5), border = alpha(cols[3], 0.5))
lines(xgrid, y_hat, lwd = 2, col = 1)

Step functions

To use step functions, we need to create dummy variables that tell us which interval each observation belongs to. To do so, we can use the cut() function, which turns a numeric variable into a factor whose levels are the intervals.

# Number of intervals (bins) of the step function.
K <- 20
# K - 1 equally spaced interior cut points over the horsepower range. The
# outer breaks are -Inf/+Inf, so every value (including grid points outside
# the observed range) falls into the first or last bin. seq_len() keeps this
# correct for K = 1, where 1:(K-1) would produce the spurious sequence c(1, 0).
step_width <- diff(range(Auto$horsepower)) / K
breaks <- c(-Inf, min(Auto$horsepower) + seq_len(K - 1) * step_width, Inf)
cut(Auto$horsepower, breaks = breaks)
  [1] (129,138]   (156,166]   (147,156]   (147,156]   (138,147]   (193,202]  
  [7] (212,221]   (212,221]   (221, Inf]  (184,193]   (166,175]   (156,166]  
 [13] (147,156]   (221, Inf]  (92,101]    (92,101]    (92,101]    (82.8,92]  
 [19] (82.8,92]   (-Inf,55.2] (82.8,92]   (82.8,92]   (92,101]    (110,120]  
 [25] (82.8,92]   (212,221]   (193,202]   (202,212]   (184,193]   (82.8,92]  
 [31] (82.8,92]   (92,101]    (92,101]    (101,110]   (92,101]    (82.8,92]  
 [37] (92,101]    (156,166]   (175,184]   (147,156]   (147,156]   (175,184]  
 [43] (166,175]   (175,184]   (101,110]   (64.4,73.6] (92,101]    (82.8,92]  
 [49] (82.8,92]   (82.8,92]   (64.4,73.6] (73.6,82.8] (64.4,73.6] (64.4,73.6]
 [55] (55.2,64.4] (64.4,73.6] (92,101]    (73.6,82.8] (-Inf,55.2] (82.8,92]  
 [61] (82.8,92]   (156,166]   (175,184]   (147,156]   (147,156]   (147,156]  
 [67] (202,212]   (147,156]   (156,166]   (184,193]   (92,101]    (147,156]  
 [73] (129,138]   (138,147]   (147,156]   (110,120]   (73.6,82.8] (82.8,92]  
 [79] (64.4,73.6] (82.8,92]   (82.8,92]   (92,101]    (73.6,82.8] (82.8,92]  
 [85] (175,184]   (147,156]   (138,147]   (129,138]   (147,156]   (193,202]  
 [91] (147,156]   (156,166]   (147,156]   (212,221]   (221, Inf]  (175,184]  
 [97] (101,110]   (92,101]    (92,101]    (82.8,92]   (92,101]    (-Inf,55.2]
[103] (147,156]   (166,175]   (166,175]   (175,184]   (92,101]    (82.8,92]  
[109] (64.4,73.6] (92,101]    (82.8,92]   (82.8,92]   (101,110]   (82.8,92]  
[115] (138,147]   (221, Inf]  (-Inf,55.2] (73.6,82.8] (82.8,92]   (110,120]  
[121] (147,156]   (101,110]   (120,129]   (175,184]   (92,101]    (92,101]   
[127] (92,101]    (64.4,73.6] (73.6,82.8] (64.4,73.6] (73.6,82.8] (92,101]   
[133] (101,110]   (101,110]   (138,147]   (147,156]   (147,156]   (138,147]  
[139] (147,156]   (82.8,92]   (64.4,73.6] (73.6,82.8] (-Inf,55.2] (55.2,64.4]
[145] (73.6,82.8] (73.6,82.8] (73.6,82.8] (92,101]    (92,101]    (64.4,73.6]
[151] (92,101]    (101,110]   (64.4,73.6] (64.4,73.6] (166,175]   (138,147]  
[157] (147,156]   (147,156]   (101,110]   (101,110]   (101,110]   (92,101]   
[163] (101,110]   (101,110]   (129,138]   (73.6,82.8] (82.8,92]   (92,101]   
[169] (73.6,82.8] (92,101]    (64.4,73.6] (92,101]    (92,101]    (64.4,73.6]
[175] (82.8,92]   (92,101]    (82.8,92]   (92,101]    (110,120]   (-Inf,55.2]
[181] (82.8,92]   (73.6,82.8] (82.8,92]   (73.6,82.8] (82.8,92]   (138,147]  
[187] (147,156]   (120,129]   (147,156]   (92,101]    (101,110]   (73.6,82.8]
[193] (82.8,92]   (-Inf,55.2] (55.2,64.4] (64.4,73.6] (-Inf,55.2] (92,101]   
[199] (73.6,82.8] (101,110]   (92,101]    (64.4,73.6] (64.4,73.6] (73.6,82.8]
[205] (64.4,73.6] (101,110]   (147,156]   (82.8,92]   (101,110]   (120,129]  
[211] (175,184]   (138,147]   (129,138]   (147,156]   (64.4,73.6] (73.6,82.8]
[217] (55.2,64.4] (92,101]    (64.4,73.6] (138,147]   (101,110]   (138,147]  
[223] (129,138]   (101,110]   (101,110]   (92,101]    (92,101]    (175,184]  
[229] (166,175]   (184,193]   (147,156]   (73.6,82.8] (82.8,92]   (73.6,82.8]
[235] (82.8,92]   (55.2,64.4] (82.8,92]   (64.4,73.6] (73.6,82.8] (92,101]   
[241] (101,110]   (101,110]   (-Inf,55.2] (64.4,73.6] (-Inf,55.2] (64.4,73.6]
[247] (55.2,64.4] (101,110]   (138,147]   (138,147]   (101,110]   (92,101]   
[253] (82.8,92]   (82.8,92]   (92,101]    (82.8,92]   (101,110]   (82.8,92]  
[259] (101,110]   (120,129]   (138,147]   (156,166]   (138,147]   (138,147]  
[265] (64.4,73.6] (92,101]    (92,101]    (73.6,82.8] (92,101]    (101,110]  
[271] (82.8,92]   (92,101]    (101,110]   (120,129]   (110,120]   (129,138]  
[277] (64.4,73.6] (64.4,73.6] (110,120]   (82.8,92]   (82.8,92]   (82.8,92]  
[283] (101,110]   (129,138]   (129,138]   (129,138]   (129,138]   (147,156]  
[289] (138,147]   (120,129]   (147,156]   (64.4,73.6] (64.4,73.6] (73.6,82.8]
[295] (73.6,82.8] (73.6,82.8] (120,129]   (64.4,73.6] (82.8,92]   (64.4,73.6]
[301] (64.4,73.6] (64.4,73.6] (64.4,73.6] (82.8,92]   (110,120]   (110,120]  
[307] (82.8,92]   (73.6,82.8] (55.2,64.4] (64.4,73.6] (64.4,73.6] (82.8,92]  
[313] (82.8,92]   (82.8,92]   (82.8,92]   (73.6,82.8] (82.8,92]   (73.6,82.8]
[319] (82.8,92]   (73.6,82.8] (64.4,73.6] (101,110]   (64.4,73.6] (-Inf,55.2]
[325] (-Inf,55.2] (64.4,73.6] (64.4,73.6] (64.4,73.6] (64.4,73.6] (55.2,64.4]
[331] (129,138]   (92,101]    (82.8,92]   (64.4,73.6] (82.8,92]   (82.8,92]  
[337] (82.8,92]   (101,110]   (82.8,92]   (55.2,64.4] (55.2,64.4] (55.2,64.4]
[343] (64.4,73.6] (64.4,73.6] (55.2,64.4] (64.4,73.6] (55.2,64.4] (64.4,73.6]
[349] (64.4,73.6] (73.6,82.8] (73.6,82.8] (73.6,82.8] (92,101]    (73.6,82.8]
[355] (73.6,82.8] (73.6,82.8] (110,120]   (120,129]   (101,110]   (101,110]  
[361] (82.8,92]   (82.8,92]   (82.8,92]   (82.8,92]   (82.8,92]   (82.8,92]  
[367] (82.8,92]   (82.8,92]   (82.8,92]   (73.6,82.8] (64.4,73.6] (64.4,73.6]
[373] (55.2,64.4] (64.4,73.6] (82.8,92]   (73.6,82.8] (64.4,73.6] (64.4,73.6]
[379] (64.4,73.6] (64.4,73.6] (101,110]   (82.8,92]   (82.8,92]   (110,120]  
[385] (92,101]    (82.8,92]   (82.8,92]   (82.8,92]   (-Inf,55.2] (82.8,92]  
[391] (73.6,82.8] (73.6,82.8]
20 Levels: (-Inf,55.2] (55.2,64.4] (64.4,73.6] (73.6,82.8] ... (221, Inf]

This is treated like a categorical variable (factor)! So lm() will create \(K-1\) dummy coefficients, plus an intercept. Remember how to interpret these? The intercept is the mean response in the baseline (first) interval, and each other coefficient is the difference between the mean response in its interval and the baseline.

# Step-function regression: cut() turns horsepower into a factor with one
# level per interval, so lm() fits a separate constant in each interval
# (K - 1 dummy coefficients plus the intercept).
fit <- lm(mpg ~ cut(horsepower, breaks = breaks), data = Auto)

# The same `breaks` vector is applied again to the grid. This assumes its
# outer breaks are -Inf/+Inf, so grid points beyond the data still get a bin.
pred <- predict(fit, newdata = list(horsepower = xgrid), se.fit = TRUE)
y_hat <- as.numeric(pred$fit)
sd_f <- as.numeric(pred$se.fit)

par(mar = c(4, 4, 2, 2), family = 'serif')
plot(Auto$horsepower, Auto$mpg, pch = 16, xlab = 'Horsepower',
     ylab = 'miles per gallon', main = 'Step regression')
# Shaded band: fitted value +/- 2 standard errors.
polygon(c(xgrid, rev(xgrid)), c(y_hat - 2 * sd_f, rev(y_hat + 2 * sd_f)),
        col = alpha(cols[3], 0.5), border = alpha(cols[3], 0.5))
lines(xgrid, y_hat, lwd = 2, col = 1)


Splines

We finally talk about how to use splines in practice.

B-splines

It is easy to expand the covariate \(X\) in the B-spline basis using the command bs(..., knots = ..., degree = ...) from the splines package.

library(splines)

# Single interior knot (in horsepower units) of the cubic B-spline fit; kept
# in one variable so the model and the plotted knot marker cannot disagree.
hp_knot <- 125
fit <- lm(mpg ~ bs(horsepower, knots = hp_knot, degree = 3), data = Auto)
# xgrid extends beyond the data, so bs() may warn about values outside the
# boundary knots when building the prediction basis.
pred <- predict(fit, newdata = list(horsepower = xgrid), se.fit = TRUE)

par(mar = c(4, 4, 2, 2), family = 'serif')
plot(Auto$horsepower, Auto$mpg, pch = 16, xlab = 'Horsepower',
     ylab = 'miles per gallon', main = 'Spline regression')
# Shaded band: fitted value +/- 2 standard errors.
polygon(c(xgrid, rev(xgrid)),
        c(pred$fit - 2 * pred$se.fit, rev(pred$fit + 2 * pred$se.fit)),
        col = alpha(cols[3], 0.5), border = alpha(cols[3], 0.5))
lines(xgrid, pred$fit, lwd = 2, col = 1)
# Red dashed line at the interior knot.
abline(v = hp_knot, lty = 2, lwd = 2, col = 2)

The red dashed line marks the only internal knot (horsepower = 125).

Let us now show the B-spline basis: first, we can plot the B-spline basis of degree \(1\).

# Standalone illustration of the B-spline basis on [0, 10]. (No rm(list = ls())
# here: it only clears variables, not packages or options, and is discouraged
# in scripts.)
xgrid <- seq(0, 10, length.out = 1000)
# Interior knots xi_1, xi_2, xi_3.
inner_knots <- c(2.5, 5, 7.5)

# intercept = TRUE keeps every basis function, so each column of the basis
# matrix is one b_k(x) and matplot() draws one curve per column.
basis <- bs(xgrid, knots = inner_knots, intercept = TRUE, degree = 1)

par(mar = c(4, 4, 2, 2), family = 'serif')
matplot(xgrid, as.matrix(basis), type = 'l', lty = 1, lwd = 2, xlab = 'x',
        ylab = bquote(b[k](x)), xaxt = 'n')
# Dashed lines at the ends of the grid and at the interior knots.
abline(v = c(min(xgrid), inner_knots, max(xgrid)), lty = 2)
axis(1, at = inner_knots,
     labels = c(TeX("$\\xi_1$"), TeX("$\\xi_2$"), TeX("$\\xi_3$")))

Let us now plot the B-spline basis of degree \(3\).

# Same interior knots as the degree-1 plot; reuses the [0, 10] grid `xgrid`.
inner_knots <- c(2.5, 5, 7.5)
# Cubic B-spline basis; intercept = TRUE keeps every basis function b_k(x).
basis <- bs(xgrid, knots = inner_knots, intercept = TRUE, degree = 3)

par(mar = c(4, 4, 2, 2), family = 'serif')
matplot(xgrid, as.matrix(basis), type = 'l', lty = 1, lwd = 2, xlab = 'x',
        ylab = bquote(b[k](x)), xaxt = 'n')
# Dashed lines at the ends of the grid and at the interior knots.
abline(v = c(min(xgrid), inner_knots, max(xgrid)), lty = 2)
axis(1, at = inner_knots,
     labels = c(TeX("$\\xi_1$"), TeX("$\\xi_2$"), TeX("$\\xi_3$")))

Natural splines

Let’s compare the fit of B-splines with that of natural splines. Remember that natural splines add constraints at the boundaries: regardless of the degree of the spline, the fitted function is required to be linear beyond the boundary knots, i.e. outside the observed range of \(X\). This usually makes the fit more stable (lower variance) near and beyond the edges of the data.

# (No rm(list = ls()): it only clears variables and is discouraged in scripts;
# everything this chunk needs is (re)defined below.)
cols <- brewer.pal(9, "Set1")

data(Auto)
# 100-point grid over the observed horsepower range plus 20 units on each side.
xgrid <- seq(min(Auto$horsepower) - 20, max(Auto$horsepower) + 20,
             length.out = 100)

# Cubic B-spline with 5 degrees of freedom (knot locations chosen by bs()).
fit <- lm(mpg ~ bs(horsepower, df = 5), data = Auto)
pred <- predict(fit, newdata = list(horsepower = xgrid), se.fit = TRUE)

par(mar = c(4, 4, 2, 2), family = 'serif')
plot(Auto$horsepower, Auto$mpg, pch = 16, xlab = 'Horsepower',
     ylab = 'miles per gallon', main = '')
# Shaded band: fitted value +/- 2 standard errors.
polygon(c(xgrid, rev(xgrid)),
        c(pred$fit - 2 * pred$se.fit, rev(pred$fit + 2 * pred$se.fit)),
        col = alpha(cols[1], 0.5), border = alpha(cols[1], 0.5))
lines(xgrid, pred$fit, lwd = 2, col = cols[1])

# Natural cubic spline with the same 5 degrees of freedom.
fit <- lm(mpg ~ ns(horsepower, df = 5), data = Auto)
pred <- predict(fit, newdata = list(horsepower = xgrid), se.fit = TRUE)

polygon(c(xgrid, rev(xgrid)),
        c(pred$fit - 2 * pred$se.fit, rev(pred$fit + 2 * pred$se.fit)),
        col = alpha(cols[2], 0.5), border = alpha(cols[2], 0.5))
lines(xgrid, pred$fit, lwd = 2, col = cols[2])
legend('topright', legend = c('B-splines', 'Natural splines'),
       col = cols[1:2], lwd = 2)

Smoothing splines

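Smoothing splines arise from a different optimization problem: instead of fixing the knots, they minimize the residual sum of squares plus a penalty on the roughness of the function, measured through its second derivative, \[\sum_{i=1}^n \big(y_i - g(x_i)\big)^2 + \lambda \int g''(t)^2 \, dt.\] The tuning parameter \(\lambda \ge 0\) controls the smoothness; smooth.spline() also lets us specify the equivalent degrees of freedom (df) instead of \(\lambda\).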

# Redraw the B-spline and natural-spline fits (5 df each) for comparison.
fit <- lm(mpg ~ bs(horsepower, df = 5), data = Auto)
pred <- predict(fit, newdata = list(horsepower = xgrid), se.fit = TRUE)

par(mar = c(4, 4, 2, 2), family = 'serif')
plot(Auto$horsepower, Auto$mpg, pch = 16, xlab = 'Horsepower',
     ylab = 'miles per gallon', main = '')
polygon(c(xgrid, rev(xgrid)),
        c(pred$fit - 2 * pred$se.fit, rev(pred$fit + 2 * pred$se.fit)),
        col = alpha(cols[1], 0.5), border = alpha(cols[1], 0.5))
lines(xgrid, pred$fit, lwd = 2, col = cols[1])

fit <- lm(mpg ~ ns(horsepower, df = 5), data = Auto)
pred <- predict(fit, newdata = list(horsepower = xgrid), se.fit = TRUE)

polygon(c(xgrid, rev(xgrid)),
        c(pred$fit - 2 * pred$se.fit, rev(pred$fit + 2 * pred$se.fit)),
        col = alpha(cols[2], 0.5), border = alpha(cols[2], 0.5))
lines(xgrid, pred$fit, lwd = 2, col = cols[2])

# Smoothing spline with the penalty tuned to 5 equivalent degrees of freedom.
fit <- smooth.spline(Auto$horsepower, y = Auto$mpg, df = 5)
# predict.smooth.spline() has no standard-error option (a `se = T` argument
# would be silently absorbed by `...`), so only the curve is drawn.
pred <- predict(fit, x = xgrid)
lines(xgrid, pred$y, lwd = 2, col = cols[3])

legend('topright',
       legend = c('B-splines', 'Natural splines', 'Smoothing splines'),
       col = cols[1:3], lwd = 2)

Choosing the degrees of freedom

Usually, we let R decide the location of the knots (by default, bs() places the internal knots at quantiles of \(X\)), we choose a degree (e.g. cubic splines), and we then choose the degrees of freedom, which determine the number of knots, via cross validation.

Let us generate some data with mean function \(\sin(2x)\) and Gaussian noise with standard deviation \(0.5\).

library(plotrix)

# (No rm(list = ls()): it only clears variables and is discouraged in scripts;
# everything the simulation needs is defined below.)
cols <- brewer.pal(9, "Set1")

set.seed(123)
n <- 400
# X ~ Uniform(0, 10) and Y = sin(2 X) + Gaussian noise with sd 0.5.
X <- runif(n, 0, 10)
Y <- rnorm(n, mean = sin(2 * X), sd = 0.5)
data <- data.frame(X = X, Y = Y)
# Grid over the simulated X range, used to draw the fitted curves later.
xgrid <- seq(0, 10, length.out = 100)

par(mar = c(4, 4, 2, 2), family = 'serif')
plot(X, Y, pch = 16, xlab = '', ylab = '', main = '')

Now we use 10-fold cross validation to estimate the test error of cubic B-splines with 3 to 30 degrees of freedom, and compare it with the training error.

K <- 10
d_max <- 30
# Candidate degrees of freedom; a cubic bs() basis needs df >= 3.
d_grid <- 3:d_max
# Random assignment of the n observations to K (nearly) equal-sized folds.
folds <- cut(sample(1:n), breaks = K, labels = FALSE)
train_err <- rep(NA_real_, length(d_grid))
cv_err <- matrix(NA_real_, nrow = K, ncol = length(d_grid))

for (j in seq_along(d_grid)) {
  d <- d_grid[j]
  # Training error: fit and evaluate on the full data set.
  fit <- lm(Y ~ bs(X, df = d), data = data)
  train_err[j] <- mean((fitted(fit) - data$Y)^2)
  for (k in seq_len(K)) {
    idx_tr <- which(folds != k)
    idx_ts <- which(folds == k)
    # NOTE(review): with `subset =`, model.frame() evaluates bs() on the full
    # X column before subsetting, so the knots use all X values (never Y).
    # Confirm this is intended.
    fit <- lm(Y ~ bs(X, df = d), data = data, subset = idx_tr)
    pred <- as.numeric(predict(fit, newdata = data[idx_ts, ]))
    cv_err[k, j] <- mean((pred - data$Y[idx_ts])^2)
  }
}
cv_MSE <- colMeans(cv_err)
# Standard error of the CV estimate across the K folds.
sd_cv_MSE <- apply(cv_err, 2, sd) / sqrt(K)

# plotCI() sizes the y-axis from the CV intervals only. Widen it so the
# (smaller) training errors are not clipped off the plot.
y_lim <- range(cv_MSE - sd_cv_MSE, cv_MSE + sd_cv_MSE, train_err)
plotCI(d_grid, cv_MSE, uiw = sd_cv_MSE, pch = 16, col = cols[1], xlab = 'df',
       ylab = 'MSE', ylim = y_lim)
points(d_grid, train_err, pch = 16)
# Dashed line at the df with the smallest CV error.
abline(v = d_grid[which.min(cv_MSE)], lwd = 2, lty = 2)
legend('topright', legend = c(paste0(K, '-fold CV'), 'Training'),
       col = c(cols[1], 1), pch = 16)

Finally, we show the optimal fit along with two sub-optimal ones: an underfitted spline (df = 4) and an overfitted spline (df = 30).

# Fit the optimal spline model
fit <- lm(Y ~ bs(X, df = (3:d_max)[which.min(cv_MSE)]), data = data)
pred <- predict(fit, newdata = list(X = xgrid), se = T)

par(mar=c(4,4,2,2), family = 'serif')
plot(X, Y, pch = 16, xlab = '', ylab = '', main = '')
polygon(c(xgrid,rev(xgrid)), c(pred$fit - 2 * pred$se,rev(pred$fit + 2 * pred$se)),
        col = alpha(cols[1], 0.5), border = alpha(cols[1], 0.5))
lines(xgrid, pred$fit, lwd = 2, col = cols[1])

fit <- lm(Y ~ bs(X, df = (3:d_max)[2]), data = data)
pred <- predict(fit, newdata = list(X = xgrid), se = T)
polygon(c(xgrid,rev(xgrid)), c(pred$fit - 2 * pred$se,rev(pred$fit + 2 * pred$se)),
        col = alpha(cols[2], 0.5), border = alpha(cols[2], 0.5))
lines(xgrid, pred$fit, lwd = 2, col = cols[2])

fit <- lm(Y ~ bs(X, df = d_max), data = data)
pred <- predict(fit, newdata = list(X = xgrid), se = T)
polygon(c(xgrid,rev(xgrid)), c(pred$fit - 2 * pred$se,rev(pred$fit + 2 * pred$se)),
        col = alpha(cols[3], 0.5), border = alpha(cols[3], 0.5))
lines(xgrid, pred$fit, lwd = 2, col = cols[3])

legend('topright', legend = c('Optimal','Underfitted','Overfitted'), col = cols[1:3], lwd = 2)


Generalized additive models

Regression

# NOTE(review): the model below is fitted and plotted through explicit mgcv::
# calls; nothing in this chunk uses the gam package. Confirm whether it is needed.
library(gam)
# (No rm(list = ls()): it only clears variables and is discouraged in scripts.)

data(Auto)

# Additive model for mpg: smooth terms s() for horsepower and acceleration,
# plus a separate effect for each model year.
fit_gam <- mgcv::gam(mpg ~ s(horsepower) +
                       s(acceleration) +
                       factor(year), data = Auto)

# One panel per term; all.terms = TRUE also plots the factor(year) effect.
par(mfrow = c(1, 3), mar = c(4, 4, 2, 2), family = 'serif')
mgcv::plot.gam(fit_gam, all.terms = TRUE)

Classification

# (No rm(list = ls()): it only clears variables and is discouraged in scripts.)
cols <- brewer.pal(9, "Set1")

data("Default")

# Logistic additive model: with family = binomial and a factor response, the
# model is for the log-odds of the non-baseline level of `default`. Income and
# balance enter as smooth terms; student status enters as a factor.
fit_gam <- mgcv::gam(default ~ s(income) +
                       s(balance) +
                       factor(student),
                     family = binomial,
                     data = Default)

# One panel per term (effects on the linear-predictor scale); all.terms = TRUE
# also plots the factor(student) effect.
par(mfrow = c(1, 3), mar = c(4, 4, 2, 2), family = 'serif')
mgcv::plot.gam(fit_gam, all.terms = TRUE)